# Parent image: the repository prefix comes from the IMAGE_NAME build argument,
# while the tag is assembled from the CUDA version, cuDNN target, OS distro and
# version, and an optional suffix taken from the template data.
ARG IMAGE_NAME
FROM ${IMAGE_NAME}:{{ cuda.version.major }}.{{ cuda.version.minor }}-{{ cuda.cudnn.target }}-{{ cuda.os.distro }}{{ cuda.os.version }}{{ cuda.image_tag_suffix }} AS base

{% if "x86_64" in cuda %}
# amd64 stage: only records the cuDNN package coordinates as environment
# variables. The packages themselves are installed by the final stage.
FROM base AS base-amd64

# Full package version string exactly as provided by the template data.
ENV NV_CUDNN_PACKAGE_VERSION="{{ cuda.x86_64.cudnn.version }}"
# Same string with its last two characters removed
# (presumably a package revision suffix such as "-1"; confirm against the data).
ENV NV_CUDNN_VERSION="{{ cuda.x86_64.cudnn.version[:-2] }}"

# "libcudnn" followed by the FIRST CHARACTER of the version string.
# NOTE(review): this assumes a single-digit cuDNN major version.
ENV NV_CUDNN_PACKAGE_NAME="libcudnn{{ cuda.x86_64.cudnn.version[0] }}"
# Kept as its own ENV instruction so the variables defined above are expanded.
ENV NV_CUDNN_PACKAGE="${NV_CUDNN_PACKAGE_NAME}=${NV_CUDNN_PACKAGE_VERSION}+cuda{{ cuda.version.major }}.{{ cuda.version.minor }}"
{% if cuda.cudnn.target != "runtime" %}
# Non-runtime targets also pin the matching -dev package.
ENV NV_CUDNN_PACKAGE_DEV="${NV_CUDNN_PACKAGE_NAME}-dev=${NV_CUDNN_PACKAGE_VERSION}+cuda{{ cuda.version.major }}.{{ cuda.version.minor }}"
{% endif %}

{% endif %}
{% if "ppc64le" in cuda %}
# ppc64le stage: only records the cuDNN package coordinates as environment
# variables. The packages themselves are installed by the final stage.
FROM base AS base-ppc64le

# Full package version string exactly as provided by the template data.
ENV NV_CUDNN_PACKAGE_VERSION="{{ cuda.ppc64le.cudnn.version }}"
# Same string with its last two characters removed
# (presumably a package revision suffix such as "-1"; confirm against the data).
ENV NV_CUDNN_VERSION="{{ cuda.ppc64le.cudnn.version[:-2] }}"

# "libcudnn" followed by the FIRST CHARACTER of the version string.
# NOTE(review): this assumes a single-digit cuDNN major version.
ENV NV_CUDNN_PACKAGE_NAME="libcudnn{{ cuda.ppc64le.cudnn.version[0] }}"
# Kept as its own ENV instruction so the variables defined above are expanded.
ENV NV_CUDNN_PACKAGE="${NV_CUDNN_PACKAGE_NAME}=${NV_CUDNN_PACKAGE_VERSION}+cuda{{ cuda.version.major }}.{{ cuda.version.minor }}"
{% if cuda.cudnn.target != "runtime" %}
# Non-runtime targets also pin the matching -dev package.
ENV NV_CUDNN_PACKAGE_DEV="${NV_CUDNN_PACKAGE_NAME}-dev=${NV_CUDNN_PACKAGE_VERSION}+cuda{{ cuda.version.major }}.{{ cuda.version.minor }}"
{% endif %}

{% endif -%}
# Final stage: continues from the per-architecture stage whose name matches the
# TARGETARCH platform argument, inheriting the NV_CUDNN_* variables set there.
FROM base-${TARGETARCH}

# Re-declared inside the stage so the platform argument is in scope here.
ARG TARGETARCH

LABEL maintainer="NVIDIA CORPORATION <cudatools@nvidia.com>"

{% if cuda.version.major | int == 8 %}
# CUDA 8 only: register an additional apt source before installing cuDNN.
# NOTE(review): NVARCH is not defined in this file; it is presumably exported
# by the parent image. Confirm before changing the parent image.
RUN echo "deb {{ cuda.ml_repo_url }}/${NVARCH} /" > /etc/apt/sources.list.d/nvidia-ml.list

{% endif -%}

ENV CUDNN_VERSION=${NV_CUDNN_VERSION}

LABEL com.nvidia.cudnn.version="${CUDNN_VERSION}"

# Install the pinned cuDNN package(s), hold the runtime package so later apt
# operations do not change it, and drop the apt lists in the same layer.
# NV_CUDNN_PACKAGE_DEV is only set for non-runtime targets; when unset it
# expands to nothing and only the runtime package is installed.
RUN apt-get update && apt-get install -y --no-install-recommends \
    ${NV_CUDNN_PACKAGE} \
    ${NV_CUDNN_PACKAGE_DEV} \
    && apt-mark hold ${NV_CUDNN_PACKAGE_NAME} \
    && rm -rf /var/lib/apt/lists/*
